1D Clusterless Decoding¶
%reload_ext autoreload
%autoreload 2
import spyglass as nd
# ignore datajoint+jupyter async warnings
import warnings
warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.simplefilter("ignore", category=ResourceWarning)
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import logging
FORMAT = "%(asctime)s %(message)s"
logging.basicConfig(level="INFO", format=FORMAT, datefmt="%d-%b-%y %H:%M:%S")
nwb_copy_file_name = "chimi20200216_new_.nwb"
UnitMarksIndicator¶
The first thing we need is the marks indicator for the clusterless decoding. See the Extract_Mark_indicators.ipynb notebook for information on how to populate this table. We will use the special method fetch_xarray to get a labeled array of shape (n_time, n_mark_features, n_electrodes). Time will be in 2 ms bins, where the value is NaN if no spike occurs and the spike features if a spike occurred.
If more than one spike from a single tetrode falls in a single time bin, we simply average the marks. Technically this isn't ideal (we should use all of the marks), but it doesn't seem to happen very often and the decodes seem robust to it.
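As a minimal, hypothetical sketch of the binning just described (not spyglass's actual implementation), one mark feature from one tetrode could be averaged into 2 ms bins like this:

```python
import numpy as np

def bin_marks(spike_times, features, n_bins, bin_size=0.002):
    """Average one mark feature into fixed-width time bins.

    Bins with no spikes stay NaN; bins with several spikes get the mean,
    mirroring the averaging described above.
    """
    binned = np.full(n_bins, np.nan)
    bin_idx = (spike_times // bin_size).astype(int)
    for b in np.unique(bin_idx):
        binned[b] = features[bin_idx == b].mean()
    return binned

# Two spikes land in bin 0 and are averaged; bin 2 has a single spike
spike_times = np.array([0.0001, 0.0012, 0.0051])
amplitudes = np.array([-90.0, -110.0, -100.0])
binned = bin_marks(spike_times, amplitudes, n_bins=4)
# binned → [-100.0, nan, -100.0, nan]
```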
The UnitMarksIndicator table depends on an interval from the interval list and a sampling rate.
import pandas as pd
from spyglass.decoding.clusterless import UnitMarksIndicator
marks = (
UnitMarksIndicator
& {
"nwb_file_name": nwb_copy_file_name,
"sort_interval_name": "runs_noPrePostTrialTimes raw data valid times",
"filter_parameter_set_name": "franklab_default_hippocampus",
"unit_inclusion_param_name": "all2",
"mark_param_name": "default",
"interval_list_name": "pos 1 valid times",
"sampling_rate": 500,
}
).fetch_xarray()
marks
<xarray.DataArray (time: 655645, marks: 4, electrodes: 22)>
array([[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]],
[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]],
[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]],
...,
[[ -99., nan, nan, ..., nan, nan, nan],
[-100., nan, nan, ..., nan, nan, nan],
[ -94., nan, nan, ..., nan, nan, nan],
[-104., nan, nan, ..., nan, nan, nan]],
[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]],
[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]]])
Coordinates:
* time (time) float64 1.582e+09 1.582e+09 ... 1.582e+09 1.582e+09
* marks (marks) <U14 'amplitude_0000' ... 'amplitude_0003'
  * electrodes  (electrodes) int64 0 1 2 3 5 6 7 8 9 ... 15 16 17 18 19 21 22 23

After you get the marks, it is important to visualize them to make sure they look right. We can use the plot_all_marks method of UnitMarksIndicator to quickly plot each mark feature against the others for each electrode.
Here it is important to look for things that look overly correlated (strong diagonal on the off-diagonal plots) and for extreme amplitudes.
UnitMarksIndicator.plot_all_marks(marks)
Position¶
After the marks look good, you'll need to load/populate the 2D position data. This comes from the IntervalPositionInfo table. Refer to the notebook 4_position_information.ipynb for more information. Note that we will need to upsample the position data (which is done here via the default_decoding parameters) to match the sampling frequency that we intend to decode at (2 ms time bins, or a 500 Hz sampling rate).
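To make the upsampling concrete, here is a hedged sketch (plain pandas, not the default_decoding code path) of linearly interpolating a slower, hypothetical 30 Hz position signal onto a 500 Hz (2 ms) timebase:

```python
import numpy as np
import pandas as pd

# Hypothetical position samples at 30 Hz over one second
time_30hz = np.arange(31) / 30
pos = pd.DataFrame(
    {"head_position_x": np.linspace(0.0, 10.0, time_30hz.size)}, index=time_30hz
)

# Target 2 ms (500 Hz) timebase inside the same interval
time_500hz = np.arange(0.0, 1.0, 0.002)
upsampled = (
    pos.reindex(pos.index.union(time_500hz))
    .interpolate(method="index")  # linear interpolation against the time index
    .reindex(time_500hz)
)
```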
from spyglass.common.common_position import IntervalPositionInfo
position_key = {
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 1 valid times",
"position_info_param_name": "default_decoding",
}
position_info = (IntervalPositionInfo() & position_key).fetch1_dataframe()
position_info
| time | head_position_x | head_position_y | head_orientation | head_velocity_x | head_velocity_y | head_speed |
|---|---|---|---|---|---|---|
| 1.581887e+09 | 91.051650 | 211.127050 | 2.680048 | 1.741550 | 2.301478 | 2.886139 |
| 1.581887e+09 | 91.039455 | 211.144123 | 3.003241 | 1.827555 | 2.333931 | 2.964320 |
| 1.581887e+09 | 91.027260 | 211.161196 | 3.008398 | 1.915800 | 2.366668 | 3.044898 |
| 1.581887e+09 | 91.015065 | 211.178268 | 3.012802 | 2.006286 | 2.399705 | 3.127901 |
| 1.581887e+09 | 91.002871 | 211.195341 | 3.017242 | 2.099012 | 2.433059 | 3.213352 |
| ... | ... | ... | ... | ... | ... | ... |
| 1.581888e+09 | 182.158583 | 201.299625 | -0.944304 | 0.057520 | -0.356012 | 0.360629 |
| 1.581888e+09 | 182.158583 | 201.296373 | -0.942329 | 0.053954 | -0.356343 | 0.360404 |
| 1.581888e+09 | 182.158583 | 201.293121 | -0.940357 | 0.050477 | -0.356407 | 0.359964 |
| 1.581888e+09 | 182.158583 | 201.289869 | -0.953059 | 0.047091 | -0.356212 | 0.359312 |
| 1.581888e+09 | 182.158583 | 201.286617 | -0.588081 | 0.043796 | -0.355764 | 0.358450 |
655645 rows × 6 columns
It is important to visualize the 2D position to make sure there are no weird values.
plt.figure(figsize=(7, 6))
plt.plot(position_info.head_position_x, position_info.head_position_y)
Next we load the linearized position tables. Refer to the notebook 5_linearization.ipynb for more information on how to create the linear position information.
from spyglass.common.common_position import IntervalLinearizedPosition
linearization_key = {
"position_info_param_name": "default_decoding",
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 1 valid times",
"track_graph_name": "6 arm",
"linearization_param_name": "default",
}
linear_position_df = (
IntervalLinearizedPosition() & linearization_key
).fetch1_dataframe()
linear_position_df
| time | linear_position | track_segment_id | projected_x_position | projected_y_position |
|---|---|---|---|---|
| 1.581887e+09 | 412.042773 | 0 | 90.802281 | 210.677533 |
| 1.581887e+09 | 412.061718 | 0 | 90.785714 | 210.686724 |
| 1.581887e+09 | 412.080664 | 0 | 90.769147 | 210.695914 |
| 1.581887e+09 | 412.099610 | 0 | 90.752579 | 210.705105 |
| 1.581887e+09 | 412.118556 | 0 | 90.736012 | 210.714296 |
| ... | ... | ... | ... | ... |
| 1.581888e+09 | 340.325042 | 1 | 175.434739 | 212.920160 |
| 1.581888e+09 | 340.323413 | 1 | 175.433329 | 212.919345 |
| 1.581888e+09 | 340.321785 | 1 | 175.431919 | 212.918529 |
| 1.581888e+09 | 340.320156 | 1 | 175.430509 | 212.917713 |
| 1.581888e+09 | 340.318527 | 1 | 175.429100 | 212.916898 |
655645 rows × 4 columns
We should also quickly visualize the linear position in order to sanity check the values. Here we plot the 2D position projected to its corresponding 1D segment.
plt.figure(figsize=(7, 6))
plt.scatter(
linear_position_df.projected_x_position,
linear_position_df.projected_y_position,
c=linear_position_df.track_segment_id,
cmap="tab20",
s=1,
)
We should also plot the linearized position itself to make sure it is okay.
plt.figure(figsize=(20, 10))
plt.scatter(
linear_position_df.index,
linear_position_df.linear_position,
s=1,
c=linear_position_df.track_segment_id,
cmap="tab20",
)
Okay, now that we've looked at the data, we should quickly verify that all our data are the same size. They may not be, because the valid intervals of the neural and position data can differ.
position_info.shape, marks.shape, linear_position_df.shape
((655645, 6), (655645, 4, 22), (655645, 4))
We also want to make sure we have valid ephys data and valid position data for decoding. Here we only have one valid time interval, but if we had more than one, we should decode on each interval separately.
from spyglass.common.common_interval import interval_list_intersect
from spyglass.common import IntervalList
key = {}
key["interval_list_name"] = "02_r1"
key["nwb_file_name"] = nwb_copy_file_name
interval = (
IntervalList
& {
"nwb_file_name": key["nwb_file_name"],
"interval_list_name": key["interval_list_name"],
}
).fetch1("valid_times")
valid_ephys_times = (
IntervalList
& {
"nwb_file_name": key["nwb_file_name"],
"interval_list_name": "raw data valid times",
}
).fetch1("valid_times")
position_interval_names = (
IntervalPositionInfo
& {
"nwb_file_name": key["nwb_file_name"],
"position_info_param_name": "default_decoding",
}
).fetch("interval_list_name")
valid_pos_times = [
(
IntervalList
& {
"nwb_file_name": key["nwb_file_name"],
"interval_list_name": pos_interval_name,
}
).fetch1("valid_times")
for pos_interval_name in position_interval_names
]
intersect_interval = interval_list_intersect(
interval_list_intersect(interval, valid_ephys_times), valid_pos_times[0]
)
valid_time_slice = slice(intersect_interval[0][0], intersect_interval[0][1])
valid_time_slice
slice(1581886916.3153033, 1581888227.5987928, None)
linear_position_df = linear_position_df.loc[valid_time_slice]
marks = marks.sel(time=valid_time_slice)
position_info = position_info.loc[valid_time_slice]
position_info.shape, marks.shape, linear_position_df.shape
((655643, 6), (655643, 4, 22), (655643, 4))
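Conceptually, interval_list_intersect computes the pairwise overlaps of two lists of [start, stop] intervals. Here is a minimal sketch of that operation using a hypothetical helper (not spyglass's implementation):

```python
import numpy as np

def intersect_intervals(a, b):
    """Return overlaps between two (n, 2) arrays of [start, stop] intervals."""
    overlaps = []
    for a_start, a_stop in a:
        for b_start, b_stop in b:
            start, stop = max(a_start, b_start), min(a_stop, b_stop)
            if start < stop:  # keep only non-empty overlaps
                overlaps.append([start, stop])
    return np.array(overlaps)

epoch = np.array([[0.0, 100.0]])
valid_pos_times = np.array([[10.0, 40.0], [60.0, 120.0]])
intersect_intervals(epoch, valid_pos_times)
# → array([[ 10.,  40.],
#          [ 60., 100.]])
```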
Decoding¶
Okay, now that we have sanity-checked the data, we can finally get to decoding. In the future this will be a pipeline, but for now it is manual because the table structure is still being prototyped.
In order to set the parameters, we can fetch the default parameters and modify them.
For 1D decoding, it is best to pass in the track graph and track graph parameters we used for linearization in order for the random walk to be handled properly. We can also set the amount of smoothing in the position and mark dimensions: position_std and mark_std respectively. Finally we set the block_size, which controls how many samples get processed at a time so that we don't run out of GPU memory.
from replay_trajectory_classification.environments import Environment
from spyglass.common.common_position import TrackGraph
from spyglass.decoding.clusterless import ClusterlessClassifierParameters
import pprint
parameters = (
ClusterlessClassifierParameters()
& {"classifier_param_name": "default_decoding_gpu"}
).fetch1()
track_graph = (
TrackGraph() & {"track_graph_name": "6 arm"}
).get_networkx_track_graph()
track_graph_params = (TrackGraph() & {"track_graph_name": "6 arm"}).fetch1()
parameters["classifier_params"]["environments"] = [
Environment(
track_graph=track_graph,
edge_order=track_graph_params["linear_edge_order"],
edge_spacing=track_graph_params["linear_edge_spacing"],
)
]
parameters["classifier_params"][
"clusterless_algorithm"
] = "multiunit_likelihood_integer_gpu"
parameters["classifier_params"]["clusterless_algorithm_params"] = {
"mark_std": 24.0,
"position_std": 6.0,
"block_size": 2**12,
}
pprint.pprint(parameters)
{'classifier_param_name': 'default_decoding_gpu',
'classifier_params': {'clusterless_algorithm': 'multiunit_likelihood_integer_gpu',
'clusterless_algorithm_params': {'block_size': 4096,
'mark_std': 24.0,
'position_std': 6.0},
'continuous_transition_types': [[RandomWalk(environment_name='', movement_var=6.0, movement_mean=0.0, use_diffusion=False),
Uniform(environment_name='', environment2_name=None)],
[Uniform(environment_name='', environment2_name=None),
Uniform(environment_name='', environment2_name=None)]],
'discrete_transition_type': DiagonalDiscrete(diagonal_value=0.98),
'environments': [Environment(environment_name='', place_bin_size=2.0, track_graph=<networkx.classes.graph.Graph object at 0x7fb807560cd0>, edge_order=[(3, 6), (6, 8), (6, 9), (3, 1), (1, 2), (1, 0), (3, 4), (4, 5), (4, 7)], edge_spacing=15, is_track_interior=None, position_range=None, infer_track_interior=True, fill_holes=False, dilate=False)],
'infer_track_interior': True,
'initial_conditions_type': UniformInitialConditions(),
'observation_models': None},
'fit_params': {},
'predict_params': {'is_compute_acausal': True,
'state_names': ['Continuous', 'Uniform'],
'use_gpu': True}}
After we set the parameters, we can run the decoding. Here we run it on the first GPU device with cp.cuda.Device(0). See the Decoding_with_GPUs_on_the_GPU_cluster.ipynb notebook for more information on how to use other GPUs.
from replay_trajectory_classification import ClusterlessClassifier
import cupy as cp
with cp.cuda.Device(0):
classifier = ClusterlessClassifier(**parameters["classifier_params"])
classifier.fit(
position=linear_position_df.linear_position.values,
multiunits=marks.values,
**parameters["fit_params"],
)
results = classifier.predict(
multiunits=marks.values,
time=linear_position_df.index,
**parameters["predict_params"],
)
logging.info("Done!")
12-Sep-22 12:15:15 Fitting initial conditions...
12-Sep-22 12:15:15 Fitting continuous state transition...
12-Sep-22 12:15:15 Fitting discrete state transition
12-Sep-22 12:15:15 Fitting multiunits...
12-Sep-22 12:15:18 Estimating likelihood...
12-Sep-22 12:15:32 Estimating causal posterior...
12-Sep-22 12:18:49 Estimating acausal posterior...
12-Sep-22 12:26:34 Done!
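One quick sanity check before plotting is to marginalize the posterior over the discrete states and extract the most probable position bin at each time step. The sketch below assumes the classifier output is an xarray.Dataset with an acausal_posterior variable over (time, state, position) dims; a toy stand-in is built here so the snippet is self-contained:

```python
import numpy as np
import xarray as xr

# Toy stand-in for the classifier output: a normalized posterior over
# (time, state, position); at each time it sums to 1 across state and position
rng = np.random.default_rng(0)
post = rng.random((5, 2, 10))
post /= post.sum(axis=(1, 2), keepdims=True)
toy_results = xr.Dataset(
    {"acausal_posterior": (("time", "state", "position"), post)},
    coords={"time": np.arange(5) * 0.002, "position": np.linspace(0.0, 100.0, 10)},
)

# Marginalize over the discrete state, then take the most probable bin per time
posterior = toy_results["acausal_posterior"].sum("state")
map_position = posterior.idxmax("position")  # (time,) array of position values
```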
Visualization¶
Finally, we can plot the decodes to make sure they make sense. We will use figurl to make an interactive figure. The function create_interactive_1D_decoding_figurl will return a URL that leads to the interactive figure. Note that this figure requires a running interactive sorting view backend.
from spyglass.decoding.visualization import (
create_interactive_1D_decoding_figurl,
)
view = create_interactive_1D_decoding_figurl(
position_info,
linear_position_df,
marks,
results,
position_name="linear_position",
speed_name="head_speed",
posterior_type="acausal_posterior",
sampling_frequency=500,
view_height=800,
)
view.url(label="")
'https://figurl.org/f?v=gs://figurl/spikesortingview-9&d=sha1://319b25199f81cc30d84f8eed471309b419c9b95d&project=lqqrbobsev&label=test'